#pragma once
#include <3rdparty/MaxFlow.hpp>
#include <Math/Random.hpp>
#include <boost/function.hpp>
#include <Math/IterExitCond.hpp>


namespace zzz{
#ifdef ZZZ_LIB_GCO
template<typename IMGT, typename GRHT=IMGT>
class ImageMultiCut
{
public:
  typedef boost::function<GRHT (const IMGT &t, const IMGT &p, zuint pos)> CalTLink;
  typedef boost::function<GRHT (const IMGT &p1, const IMGT &p2)> CalNLink;

  ImageMultiCut(CalTLink calTLink, CalNLink calNLink, const IterExitCond<GRHT> &cond)
    :CalTLink_(calTLink),CalNLink_(calNLink),cond_(cond)
  {}

  void Expansion(const Image<IMGT> &img, const vector<IMGT> &term, Array<2,zuint> &labels)
  {
    zuint labeln=term.size();

    if (labels.Size()!=img.Size())
    {
      ZLOGI << "randomly init labels\n";
      labels.SetSize(img.Size());
      RandomInteger<zuint> rand(0,labeln-1);
      for (zuint i=0; i<labels.size(); i++)
        labels[i]=rand.Rand();
    }

    //alpha_expansion
    cond_.Reset();
    int iteration=0;
    int changed;
    do {
      changed=0;
      for (zuint alpha=0;alpha<labeln;alpha++) {
        //form a new alpha cut
        MaxFlow<GRHT>::TLinks tlinks;
        //ori node
        for (zuint i=0; i<img.size(); i++) {
          if (labels[i]==alpha)
            tlinks.push_back(make_pair(CalTLink_(term[alpha],img[i],i),numeric_limits<GRHT>::max()));
          else
            tlinks.push_back(make_pair(CalTLink_(term[alpha],img[i],i),CalTLink_(term[labels[i]],img[i],i)));
        }
        //auxiliary node and NLinks
        MaxFlow<GRHT>::NLinks nlinks;
        for (zuint r=0; r<img.Rows(); r++) for (zuint c=0; c<img.Cols(); c++) {
          zuint pos1=labels.ToIndex(Vector2ui(r,c));
          if (c!=img.Cols()-1) {//right link
            zuint pos2=labels.ToIndex(Vector2ui(r,c+1));
            float alink1=CalNLink_(term[labels[pos1]],term[alpha]);
            float alink2=CalNLink_(term[alpha],term[labels[pos2]]);
            if ((labels[pos1]==alpha && labels[pos2]!=alpha) || (labels[pos1]!=alpha && labels[pos2]==alpha)) {
              //add an auxiliary node
              tlinks.push_back(make_pair(GRHT(0),CalNLink_(term[labels[pos1]],term[labels[pos2]])));
              zuint apos=tlinks.size()-1;
              nlinks.push_back(make_pair(make_pair(pos1,apos),make_pair(alink1,alink1)));
              nlinks.push_back(make_pair(make_pair(pos2,apos),make_pair(alink2,alink2)));
            } else { //normal edge
              nlinks.push_back(make_pair(make_pair(pos1,pos2),make_pair(alink1,alink2)));
            }
          }
          if (r!=img.Rows()-1) { //down link
            zuint pos2=labels.ToIndex(Vector2ui(r+1,c));
            float alink1=CalNLink_(term[labels[pos1]],term[alpha]);
            float alink2=CalNLink_(term[alpha],term[labels[pos2]]);
            if ((labels[pos1]==alpha && labels[pos2]!=alpha) || (labels[pos1]!=alpha && labels[pos2]==alpha)) {
            //add an auxiliary node
              tlinks.push_back(make_pair(GRHT(0),CalNLink_(term[labels[pos1]],term[labels[pos2]])));
              zuint apos=tlinks.size()-1;
              nlinks.push_back(make_pair(make_pair(pos1,apos),make_pair(alink1,alink1)));
              nlinks.push_back(make_pair(make_pair(pos2,apos),make_pair(alink2,alink2)));
            } else { //normal edge
              nlinks.push_back(make_pair(make_pair(pos1,pos2),make_pair(alink1,alink2)));
            }
          }
        }
        mf_.CalMaxFlow(nlinks,tlinks);
        //check if changed
        for (zuint i=0; i<img.size(); i++)
          if (mf_.InSinkSet(i) && labels[i]!=alpha) { //changed
            labels[i]=alpha;
            changed++;
          }
        ZLOGI << "alpha expansion: changed "<<changed<<" iteration: "<<iteration<<" alpha: "<<alpha<<'/'<<labeln<<endl;
      }
      iteration++;
    } while(!cond_.IsSatisfied(changed));
  }


  void Swap(const Image<IMGT> &img, const vector<GRHT> &term, Array<2,zuint> &labels)
  {
    zuint labeln=term.size();

    if (labels.Size()!=img.Size()) {
      zout<<"randomly init labels\n";
      labels.SetSize(img.Size());
      RandomInteger<zuint> rand(0,labeln-1);
      for (zuint i=0; i<labels.size(); i++)
        labels[i]=rand.Rand();
    }

    //alpha-beta swap
    int iteration=0;
    int lastchange=MAX_INT;
    int changed;
    cond_.Reset();
    do {
      changed=0;
      for (zuint alpha=0;alpha<labeln-1;alpha++) for (zuint beta=alpha+1;beta<labeln;beta++) {
        Array<2,int> oripos(img.Size());
        oripos=-1;
        //tlinks
        MaxFlow<GRHT>::TLinks tlinks;
        for (zuint i=0; i<img.size(); i++) {
          if (labels[i]!=alpha && labels[i]!=beta) continue; //keep only alpha or beta
          GRHT t_alpha=CalTLink_(term[alpha],img[i],i);
          GRHT t_beta=CalTLink_(term[beta],img[i],i);
          Vector2ui curpos=labels.ToIndex(i);
          if (curpos[0]>0) {
            zuint label=labels(Vector2ui(curpos[0]-1,curpos[1]));
            if (label!=alpha && label!=beta) {
              t_alpha+=CalNLink_(term[alpha],term[label]);
              t_beta+=CalNLink_(term[beta],term[label]);
            }
          }
          if (curpos[1]>0) {
            zuint label=labels(Vector2ui(curpos[0],curpos[1]-1));
            if (label!=alpha && label!=beta) {
              t_alpha+=CalNLink_(term[alpha],term[label]);
              t_beta+=CalNLink_(term[beta],term[label]);
            }
          }
          if (curpos[0]<img.Rows()-1) {
            zuint label=labels(Vector2ui(curpos[0]+1,curpos[1]));
            if (label!=alpha && label!=beta) {
              t_alpha+=CalNLink_(term[alpha],term[label]);
              t_beta+=CalNLink_(term[beta],term[label]);
            }
          }
          if (curpos[1]<img.Cols()-1) {
            zuint label=labels(Vector2ui(curpos[0],curpos[1]+1));
            if (label!=alpha && label!=beta) {
              t_alpha+=CalNLink_(term[alpha],term[label]);
              t_beta+=CalNLink_(term[beta],term[label]);
            }
          }
          tlinks.push_back(make_pair(t_alpha,t_beta));
          oripos[i]=tlinks.size()-1; //record ori pos
        }
        //NLinks
        MaxFlow<GRHT>::NLinks nlinks;
        float link=CalNLink_(term[labels[alpha]],term[labels[beta]]);
        for (zuint r=0; r<img.Rows(); r++) for (zuint c=0; c<img.Cols(); c++) {
          zuint pos1=labels.ToIndex(Vector2ui(r,c));
          if (labels[pos1]!=alpha && labels[pos1]!=beta) continue;
          if (c!=img.Cols()-1) { //right link
            zuint pos2=labels.ToIndex(Vector2ui(r,c+1));
            if ((labels[pos2]==alpha || labels[pos2]==beta))
              nlinks.push_back(make_pair(make_pair(zuint(oripos[pos1]),zuint(oripos[pos2])),make_pair(link,link)));
          }
          if (r!=img.Rows()-1) { //down link
            zuint pos2=labels.ToIndex(Vector2ui(r+1,c));
            if ((labels[pos2]==alpha || labels[pos2]==beta))
              nlinks.push_back(make_pair(make_pair(zuint(oripos[pos1]),zuint(oripos[pos2])),make_pair(link,link)));
          }
        }
        mf_.CalMaxFlow(nlinks,tlinks);
        //check if changed
        for (zuint i=0; i<img.size(); i++) {
          if (oripos[i]==-1) continue;
          if (mf_.InSinkSet(oripos[i]) && labels[i]!=alpha) {  //changed
            labels[i]=alpha;
            changed++;
          }
          if (mf_.InSourceSet(oripos[i]) && labels[i]!=beta) {  //changed
            labels[i]=beta;
            changed++;
          }
        }
      }
      ZLOGI<<"alpha-beta swap: round changed "<<changed<<" iteration: "<<iteration++;
    } while(!cond_.IsSatisfied(changed));
  }

private:
  CalTLink CalTLink_;
  CalNLink CalNLink_;

  MaxFlow<GRHT> mf_;

  IterExitCond<GRHT> cond_;
};
#endif
}